from typing import Any, Callable, Union, Tuple, Sequence, Optional
from .. import Tensor
from .grad_mode import no_grad as no_grad, enable_grad as enable_grad, \
    set_grad_enabled as set_grad_enabled

# TODO make Variable and Function more precise
# Opaque placeholder type: this stub declares no attributes or methods for it.
class Variable: ...

class Function:
    # Both hooks are static: their first positional argument is a context
    # object ('ctx'), not an instance of this class.

    @staticmethod
    def forward(
        ctx: Any,
        *args: Any,
        **kwargs: Any
    ) -> Any:
        # Arbitrary positional and keyword inputs are accepted after 'ctx'.
        ...

    @staticmethod
    def backward(
        ctx: Any,
        *grad_outputs: Any
    ) -> Any:
        # Any number of positional 'grad_outputs' is accepted after 'ctx'.
        ...

class NestedIOFunction(Function):
    # 'forward' and 'backward' are declared as '@staticmethod' in the superclass (Function) but are
    # instance methods here. mypy reports that override as incompatible, hence the 'type: ignore'
    # markers on those two definitions.
    def forward(self, *args: Any) -> tuple: ...  # type: ignore
    def backward(self, *gradients: Any) -> Any: ...  # type: ignore

    # Remaining instance-level hooks; each one is annotated as returning None.
    def save_for_backward(self, *args: Any) -> None: ...
    def mark_dirty(self, *args: Any, **kwargs: Any) -> None: ...
    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None: ...
    def forward_extended(self, *input: Any) -> None: ...
    def backward_extended(self, *grad_output: Any) -> None: ...

# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
# Stub signature for 'gradcheck'. 'func' and 'inputs' are required; every other
# parameter has its default elided ('...'), so the concrete default values are
# not visible from this stub.
def gradcheck(
    func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    eps: float = ...,
    atol: float = ...,
    rtol: float = ...,
    raise_exception: bool = ...,
    check_sparse_nnz: bool = ...
) -> bool: ...
# Stub signature for 'gradgradcheck'. 'func' and 'inputs' are required; every
# other parameter has its default elided ('...').
#
# 'grad_outputs' is declared third so that the positional order follows the
# documented torch.autograd.gradgradcheck signature. Without it, type checkers
# reject 'grad_outputs=...' keyword calls and accept a positional 'eps' in the
# slot that the runtime function uses for 'grad_outputs'.
#
# 'func' takes a vararg of tensors, which Callable cannot express, so its
# parameter list stays '...'.
def gradgradcheck(
    func: Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    grad_outputs: Optional[Union[Tensor, Tuple[Tensor, ...]]] = ...,
    eps: float = ...,
    atol: float = ...,
    rtol: float = ...,
    gen_non_contig_grad_outputs: bool = ...,
    raise_exception: bool = ...
) -> bool: ...

class detect_anomaly:
    # Only the context-manager protocol is declared; no constructor arguments
    # appear in this stub.
    def __enter__(self) -> None:
        ...

    # NOTE(review): a 'bool' return type lets type checkers assume this manager
    # may suppress exceptions -- confirm that this is intended.
    def __exit__(self, *args: Any) -> bool:
        ...

class set_detect_anomaly:
    # Constructed with a single boolean 'mode'; also usable as a context manager.
    def __init__(self, mode: bool) -> None:
        ...

    def __enter__(self) -> None:
        ...

    # NOTE(review): a 'bool' return type lets type checkers assume this manager
    # may suppress exceptions -- confirm that this is intended.
    def __exit__(self, *args: Any) -> bool:
        ...

# Alias for "a single Tensor, or a sequence of Tensors".
_TensorOrTensors = Union[
    Tensor,
    Sequence[Tensor],
]
# Stub signature for 'backward'. Only 'tensors' is required; the remaining
# parameters have their defaults elided ('...'). Annotated as returning None.
def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = ...,
    retain_graph: Optional[bool] = ...,
    create_graph: bool = ...
) -> None: ...
# Stub signature for 'grad'. 'outputs' and 'inputs' are required; the remaining
# parameters have their defaults elided ('...'). The declared result is a
# variable-length tuple of Tensors.
def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = ...,
    retain_graph: Optional[bool] = ...,
    create_graph: bool = ...,
    only_inputs: bool = ...,
    allow_unused: bool = ...
) -> Tuple[Tensor, ...]: ...